{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "| [09_dialogue/01_对话模型.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/09_dialogue/01_对话模型.ipynb)  | 基于transformers的Bert问答模型  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/09_dialogue/01_对话模型.ipynb) |\n",
    "\n",
    "\n",
    "# 对话模型\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用transformers的Bert模型完成阅读理解任务。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\r\n",
      "Requirement already satisfied: transformers in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (4.11.3)\r\n",
      "Requirement already satisfied: requests in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (2.25.1)\r\n",
      "Requirement already satisfied: packaging>=20.0 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (20.9)\r\n",
      "Requirement already satisfied: huggingface-hub>=0.0.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.0.19)\r\n",
      "Requirement already satisfied: sacremoses in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.0.45)\r\n",
      "Requirement already satisfied: pyyaml>=5.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (5.4.1)\r\n",
      "Requirement already satisfied: filelock in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (3.0.12)\r\n",
      "Requirement already satisfied: tokenizers<0.11,>=0.10.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (0.10.3)\r\n",
      "Requirement already satisfied: tqdm>=4.27 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (4.59.0)\r\n",
      "Requirement already satisfied: regex!=2019.12.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (2021.4.4)\r\n",
      "Requirement already satisfied: numpy>=1.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from transformers) (1.20.1)\r\n",
      "Requirement already satisfied: typing-extensions in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from huggingface-hub>=0.0.17->transformers) (3.7.4.3)\r\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from packaging>=20.0->transformers) (2.4.7)\r\n",
      "Requirement already satisfied: chardet<5,>=3.0.2 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (4.0.0)\r\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (2021.5.30)\r\n",
      "Requirement already satisfied: idna<3,>=2.5 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (2.10)\r\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from requests->transformers) (1.26.4)\r\n",
      "Requirement already satisfied: click in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (7.1.2)\r\n",
      "Requirement already satisfied: six in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (1.15.0)\r\n",
      "Requirement already satisfied: joblib in /Users/xuming/opt/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers) (1.0.1)\r\n"
     ]
    }
   ],
   "source": [
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer: '当当网', score: 0.6132, start: 14, end: 17\n",
      "Answer: 'imagenet数据集', score: 0.6238, start: 47, end: 58\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from transformers import pipeline\n",
    "from transformers import AutoModelForQuestionAnswering, BertTokenizer\n",
    "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
    "bert_model = 'luhua/chinese_pretrain_mrc_macbert_large'\n",
    "print(bert_model)\n",
    "# model = AutoModelForQuestionAnswering.from_pretrained(bert_model_dir)\n",
    "# tokenizer = BertTokenizer.from_pretrained(bert_model_dir)\n",
    "nlp = pipeline(\"question-answering\",\n",
    "               model=bert_model,\n",
    "               tokenizer=bert_model,\n",
    "               device=-1,  # gpu device id\n",
    "               )\n",
    "context = r\"\"\"\n",
    "大家好，我是张亮，目前任职当当网架构部架构师一职，也是高可用架构群的一员。我为大家提供了一份imagenet数据集，希望能够为图像分类任务做点贡献。\n",
    "\"\"\"\n",
    "\n",
    "# context = ' '.join(list(context))\n",
    "\n",
    "result = nlp(question=\"张亮在哪里任职?\", context=context)\n",
    "print(\n",
    "    f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")\n",
    "result = nlp(question=\"张亮为图像分类提供了什么数据集?\", context=context)\n",
    "print(\n",
    "    f\"Answer: '{result['answer']}', score: {round(result['score'], 4)}, start: {result['start']}, end: {result['end']}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Custom Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: 张亮在哪里上班?\n",
      "Answer: 腾 讯\n",
      "Question: 张亮有啥数据?\n",
      "Answer: imagenet 数 据 集\n",
      "Question: 谁提供了imagenet数据集?\n",
      "Answer: 张 明 亮\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering\n",
    "import torch\n",
    "bert_model = 'luhua/chinese_pretrain_mrc_macbert_large'\n",
    "# tokenizer = AutoTokenizer.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\")\n",
    "# model = AutoModelForQuestionAnswering.from_pretrained(\"bert-large-uncased-whole-word-masking-finetuned-squad\")\n",
    "model = AutoModelForQuestionAnswering.from_pretrained(bert_model)\n",
    "tokenizer = BertTokenizer.from_pretrained(bert_model)\n",
    "\n",
    "text = r\"\"\"\n",
    "大家好，我是张明亮，目前任职腾讯架构部架构师一职，也是高可用架构群的一员。我为大家提供了一份imagenet数据集，希望能够为图像分类任务做点贡献\n",
    "\"\"\"\n",
    "questions = [\n",
    "    \"张亮在哪里上班?\",\n",
    "    \"张亮有啥数据?\",\n",
    "    \"谁提供了imagenet数据集?\",\n",
    "]\n",
    "for question in questions:\n",
    "    inputs = tokenizer(question, text, add_special_tokens=True, return_tensors=\"pt\")\n",
    "    #print(f'inputs:{inputs}')\n",
    "    input_ids = inputs[\"input_ids\"].tolist()[0]\n",
    "    outputs = model(**inputs)\n",
    "    #print(outputs)\n",
    "    answer_start_scores = outputs.start_logits\n",
    "    answer_end_scores = outputs.end_logits\n",
    "    answer_start = torch.argmax(\n",
    "        answer_start_scores\n",
    "    )  # Get the most likely beginning of answer with the argmax of the score\n",
    "    answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score\n",
    "    answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))\n",
    "    print(f\"Question: {question}\")\n",
    "    print(f\"Answer: {answer}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}